{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1eae7403-f458-4f55-a557-4e045bd6f679",
   "metadata": {
    "id": "1eae7403-f458-4f55-a557-4e045bd6f679"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import models_vit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4573e6be-935a-4106-8c06-e467552b0e3d",
   "metadata": {
    "id": "4573e6be-935a-4106-8c06-e467552b0e3d"
   },
   "outputs": [],
   "source": [
    "# Per-channel RGB mean / std used to normalise images before the forward pass.\n",
    "imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "\n",
    "def prepare_model(chkpt_dir, arch='vit_large_patch16'):\n",
    "    \"\"\"Build a ViT from ``models_vit`` and load checkpoint weights into it.\n",
    "\n",
    "    Args:\n",
    "        chkpt_dir: Path to the checkpoint file; it must contain a 'model'\n",
    "            entry holding the state dict.\n",
    "        arch: Name of the constructor to look up in ``models_vit``.\n",
    "\n",
    "    Returns:\n",
    "        The constructed model. No device move is performed here.\n",
    "    \"\"\"\n",
    "    # build model\n",
    "    model = models_vit.__dict__[arch](\n",
    "        img_size=224,\n",
    "        num_classes=5,\n",
    "        drop_path_rate=0,\n",
    "        global_pool=True,\n",
    "    )\n",
    "    # load model\n",
    "    # NOTE: torch.load unpickles the file, so only open checkpoints that come\n",
    "    # from a trusted source.\n",
    "    checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
    "    # strict=False: keys that do not match between the checkpoint and the\n",
    "    # model are skipped instead of raising.\n",
    "    model.load_state_dict(checkpoint['model'], strict=False)\n",
    "    return model\n",
    "\n",
    "\n",
    "def run_one_image(img, model):\n",
    "    \"\"\"Return the latent feature of one image from ``model.forward_features``.\n",
    "\n",
    "    Args:\n",
    "        img: Array-like image in (height, width, channel) layout, already\n",
    "            normalised by the caller.\n",
    "        model: Module exposing ``forward_features``; the input is moved to\n",
    "            the device of the model's parameters.\n",
    "\n",
    "    Returns:\n",
    "        The ``forward_features`` output with every size-1 dimension squeezed\n",
    "        out. It is computed without autograd tracking.\n",
    "    \"\"\"\n",
    "    # Use the model's own device so this function does not depend on a\n",
    "    # global `device` that is only defined in a later cell.\n",
    "    target_device = next(model.parameters()).device\n",
    "\n",
    "    # Cast to float32 on the host so only half the bytes are transferred.\n",
    "    x = torch.as_tensor(img, dtype=torch.float32)\n",
    "    x = x.unsqueeze(dim=0)     # add the batch dimension: (1, h, w, c)\n",
    "    x = x.permute(0, 3, 1, 2)  # nhwc -> nchw\n",
    "    x = x.to(target_device, non_blocking=True)\n",
    "\n",
    "    # Feature extraction only: do not build an autograd graph.\n",
    "    with torch.no_grad():\n",
    "        latent = model.forward_features(x)\n",
    "    latent = torch.squeeze(latent)\n",
    "\n",
    "    return latent\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b7e691d-93d2-439f-91d6-c22716a897b5",
   "metadata": {
    "id": "8b7e691d-93d2-439f-91d6-c22716a897b5"
   },
   "source": [
    "### Load a pre-trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
   "metadata": {
    "id": "fd2d7da9-f75c-4b27-a84b-6d1247f73a7d",
    "outputId": "a1f0dba1-2cae-484b-ad84-8b00bc7628aa"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded.\n"
     ]
    }
   ],
   "source": [
    "# Load the pre-trained RETFound weights from a local checkpoint file\n",
    "# (the file must already have been downloaded to this path).\n",
    "\n",
    "chkpt_dir = './RETFound_cfp.pth'\n",
    "model_ = prepare_model(chkpt_dir, 'vit_large_patch16')\n",
    "\n",
    "# Fall back to the CPU when no CUDA device is available instead of failing.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model_.to(device)\n",
    "print('Model loaded.')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6",
   "metadata": {
    "id": "7d15a0a7-c093-439a-9a4d-c37ce0c0eaa6"
   },
   "source": [
    "### Load images and save latent features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27755296-05cc-4344-90de-a8ab3878f485",
   "metadata": {
    "id": "27755296-05cc-4344-90de-a8ab3878f485",
    "outputId": "34c3c12a-0a17-44fe-b72a-cef6eecabc70",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# get image list\n",
    "data_path = 'Your data path'\n",
    "# Sort for a reproducible order and skip sub-directories\n",
    "# (e.g. .ipynb_checkpoints), which cannot be opened as images.\n",
    "img_list = sorted(\n",
    "    entry for entry in os.listdir(data_path)\n",
    "    if os.path.isfile(os.path.join(data_path, entry))\n",
    ")\n",
    "\n",
    "name_list = []\n",
    "feature_list = []\n",
    "model_.eval()\n",
    "\n",
    "for img_name in img_list:\n",
    "    # The context manager closes the file handle; convert('RGB') guarantees\n",
    "    # three channels for grayscale / RGBA / palette inputs.\n",
    "    with Image.open(os.path.join(data_path, img_name)) as raw_img:\n",
    "        rgb_img = raw_img.convert('RGB').resize((224, 224))\n",
    "    img = np.array(rgb_img) / 255.\n",
    "\n",
    "    assert img.shape == (224, 224, 3)\n",
    "\n",
    "    # normalize by mean and sd\n",
    "    # can use customised mean and sd for your data\n",
    "    img = (img - imagenet_mean) / imagenet_std\n",
    "\n",
    "    latent_feature = run_one_image(img, model_)\n",
    "\n",
    "    name_list.append(img_name)\n",
    "    feature_list.append(latent_feature.detach().cpu().numpy())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a365ec24-8e29-485e-83b5-5ac1d02945bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store every vector as a plain Python list. Writing ndarray objects would\n",
    "# go through str(ndarray), which abbreviates arrays with more than 1000\n",
    "# elements to '[a b ... y z]' and silently drops most feature values.\n",
    "# Each cell is then a full '[v0, v1, ...]' literal that can be parsed back\n",
    "# with json.loads or ast.literal_eval.\n",
    "latent_csv = pd.DataFrame({\n",
    "    'Name': name_list,\n",
    "    'Latent_feature': [np.asarray(feature).tolist() for feature in feature_list],\n",
    "})\n",
    "latent_csv.to_csv('Feature_latent.csv', index = False, encoding='utf8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e8bd5e6-5780-420d-9d4c-96025b265668",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "environment": {
   "kernel": "python3",
   "name": "common-cu110.m91",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/base-cu110:m91"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
